{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Serving a custom model with JSON serialization\n",
    "\n",
    "The `mlserver` package comes with inference runtime implementations for `scikit-learn` and `xgboost` models.\n",
    "However, some times we may also need to roll out our own inference server, with custom logic to perform inference.\n",
    "To support this scenario, MLServer makes it really easy to create your own extensions, which can then be containerised and deployed in a production environment."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Overview\n",
    "\n",
    "In this example, we create a simple `Hello World JSON` model that parses and modifies a JSON data chunk. This is often useful as a means to\n",
    "quickly bootstrap existing models that utilize JSON based model inputs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Serving\n",
    "\n",
    "The next step will be to serve our model using `mlserver`. \n",
    "For that, we will first implement an extension which serve as the _runtime_ to perform inference using our custom `Hello World JSON` model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Custom inference runtime\n",
    "\n",
    "This is a trivial model to demonstrate how to conceptually work with JSON inputs / outputs. In this example:\n",
    "\n",
    "- Parse the JSON input from the client\n",
    "- Create a JSON response echoing back the client request as well as a server generated message\n"
   ]
  },
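  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The two steps above can be sketched in plain Python, without any `mlserver` types.\n",
    "This is only an illustration under simplifying assumptions: `build_echo_response` is a hypothetical helper name that is not part of MLServer, and it handles a single JSON string, whereas the runtime in the next cell wraps the same idea in the V2 request and response types.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def build_echo_response(raw_request: str) -> str:\n",
    "    # Hypothetical helper: parse the JSON payload sent by the client...\n",
    "    request = json.loads(raw_request)\n",
    "    # ...and echo it back next to a server-generated message.\n",
    "    response = {\n",
    "        'request': request,\n",
    "        'server_response': 'Got your request. Hello from the server.',\n",
    "    }\n",
    "    return json.dumps(response)\n",
    "\n",
    "\n",
    "# Build a sample payload the same way a client would.\n",
    "echoed = build_echo_response(json.dumps({'name': 'Foo Bar'}))\n",
    "print(echoed)\n",
    "```"
   ]
  },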
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile jsonmodels.py\n",
    "import json\n",
    "\n",
    "from typing import Dict, Any\n",
    "from mlserver import MLModel, types\n",
    "from mlserver.codecs import StringCodec\n",
    "\n",
    "\n",
    "class JsonHelloWorldModel(MLModel):\n",
    "    async def load(self) -> bool:\n",
    "        # Perform additional custom initialization here.\n",
    "        print(\"Initialize model\")\n",
    "\n",
    "        # Set readiness flag for model\n",
    "        return await super().load()\n",
    "\n",
    "    async def predict(self, payload: types.InferenceRequest) -> types.InferenceResponse:\n",
    "        request = self._extract_json(payload)\n",
    "        response = {\n",
    "            \"request\": request,\n",
    "            \"server_response\": \"Got your request. Hello from the server.\",\n",
    "        }\n",
    "        response_bytes = json.dumps(response).encode(\"UTF-8\")\n",
    "\n",
    "        return types.InferenceResponse(\n",
    "            id=payload.id,\n",
    "            model_name=self.name,\n",
    "            model_version=self.version,\n",
    "            outputs=[\n",
    "                types.ResponseOutput(\n",
    "                    name=\"echo_response\",\n",
    "                    shape=[len(response_bytes)],\n",
    "                    datatype=\"BYTES\",\n",
    "                    data=[response_bytes],\n",
    "                    parameters=types.Parameters(content_type=\"str\"),\n",
    "                )\n",
    "            ],\n",
    "        )\n",
    "\n",
    "    def _extract_json(self, payload: types.InferenceRequest) -> Dict[str, Any]:\n",
    "        inputs = {}\n",
    "        for inp in payload.inputs:\n",
    "            inputs[inp.name] = json.loads(\n",
    "                \"\".join(self.decode(inp, default_codec=StringCodec))\n",
    "            )\n",
    "\n",
    "        return inputs\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Settings files\n",
    "\n",
    "The next step will be to create 2 configuration files: \n",
    "\n",
    "- `settings.json`: holds the configuration of our server (e.g. ports, log level, etc.).\n",
    "- `model-settings.json`: holds the configuration of our model (e.g. input type, runtime to use, etc.)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `settings.json`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile settings.json\n",
    "{\n",
    "    \"debug\": \"true\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### `model-settings.json`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile model-settings.json\n",
    "{\n",
    "    \"name\": \"json-hello-world\",\n",
    "    \"implementation\": \"jsonmodels.JsonHelloWorldModel\"\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start serving our model\n",
    "\n",
    "Now that we have our config in-place, we can start the server by running `mlserver start .`. This needs to either be ran from the same directory where our config files are or pointing to the folder where they are.\n",
    "\n",
    "```shell\n",
    "mlserver start .\n",
    "```\n",
    "\n",
    "Since this command will start the server and block the terminal, waiting for requests, this will need to be ran in the background on a separate terminal."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Send test inference request (REST)\n",
    "\n",
    "\n",
    "We now have our model being served by `mlserver`.\n",
    "To make sure that everything is working as expected, let's send a request from our test set.\n",
    "\n",
    "For that, we can use the Python types that `mlserver` provides out of box, or we can build our request manually."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import json\n",
    "\n",
    "inputs = {\n",
    "    \"name\": \"Foo Bar\",\n",
    "    \"message\": \"Hello from Client (REST)!\"\n",
    "}\n",
    "\n",
    "# NOTE: this uses characters rather than encoded bytes. It is recommended that you use the `mlserver` types to assist in the correct encoding.\n",
    "inputs_string= json.dumps(inputs)\n",
    "\n",
    "inference_request = {\n",
    "    \"inputs\": [\n",
    "        {\n",
    "            \"name\": \"echo_request\",\n",
    "            \"shape\": [len(inputs_string)],\n",
    "            \"datatype\": \"BYTES\",\n",
    "            \"data\": [inputs_string]\n",
    "        }\n",
    "    ]\n",
    "}\n",
    "\n",
    "endpoint = \"http://localhost:8080/v2/models/json-hello-world/infer\"\n",
    "response = requests.post(endpoint, json=inference_request)\n",
    "\n",
    "response.json()"
   ]
  },
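  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The JSON document built by the model comes back as a string inside the `data` field of the first output.\n",
    "Below is a minimal sketch of decoding it, assuming the response follows the V2 layout used by the runtime above. Here `sample_response` is a hand-written stand-in for `response.json()`, not captured server output, and `decode_first_output` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Hand-written stand-in for `response.json()`; a real response carries more fields.\n",
    "sample_response = {\n",
    "    'model_name': 'json-hello-world',\n",
    "    'outputs': [\n",
    "        {\n",
    "            'name': 'echo_response',\n",
    "            'datatype': 'BYTES',\n",
    "            'data': [json.dumps({'server_response': 'Hello'})],\n",
    "        }\n",
    "    ],\n",
    "}\n",
    "\n",
    "\n",
    "def decode_first_output(response_json: dict) -> dict:\n",
    "    # Assumption: the payload is a JSON document stored as a single string element.\n",
    "    raw = response_json['outputs'][0]['data'][0]\n",
    "    return json.loads(raw)\n",
    "\n",
    "\n",
    "decoded = decode_first_output(sample_response)\n",
    "print(decoded['server_response'])\n",
    "```\n",
    "\n",
    "With a live server, the same helper could be applied to `response.json()` from the cell above, provided the output really arrives as a list holding one string."
   ]
  },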
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Send test inference request (gRPC)\n",
    "\n",
    "\n",
    "Utilizing string data with the gRPC interface can be a bit tricky. To ensure we are correctly handling inputs and\n",
    "outputs we will be handled correctly.\n",
    "\n",
    "For simplicity in this case, we leverage the Python types that `mlserver` provides out of the box. Alternatively,\n",
    "the gRPC stubs can be generated regenerated from the V2 specification directly for use by non-Python as well as \n",
    "Python clients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import json\n",
    "import grpc\n",
    "import mlserver.grpc.converters as converters\n",
    "import mlserver.grpc.dataplane_pb2_grpc as dataplane\n",
    "import mlserver.types as types\n",
    "\n",
    "model_name = \"json-hello-world\"\n",
    "inputs = {\n",
    "    \"name\": \"Foo Bar\",\n",
    "    \"message\": \"Hello from Client (gRPC)!\"\n",
    "}\n",
    "inputs_bytes = json.dumps(inputs).encode(\"UTF-8\")\n",
    "\n",
    "inference_request = types.InferenceRequest(\n",
    "    inputs=[\n",
    "        types.RequestInput(\n",
    "            name=\"echo_request\",\n",
    "            shape=[len(inputs_bytes)],\n",
    "            datatype=\"BYTES\",\n",
    "            data=[inputs_bytes],\n",
    "            parameters=types.Parameters(content_type=\"str\")\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "\n",
    "inference_request_g = converters.ModelInferRequestConverter.from_types(\n",
    "    inference_request,\n",
    "    model_name=model_name,\n",
    "    model_version=None\n",
    ")\n",
    "\n",
    "grpc_channel = grpc.insecure_channel(\"localhost:8081\")\n",
    "grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)\n",
    "\n",
    "response = grpc_stub.ModelInfer(inference_request_g)\n",
    "response"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
